{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "166c742c-aa92-4c4f-93a2-876f6aa395bb",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 示例：缓存及加载功能\n",
    "- 可以缓存及加载任何对象（数据、模型……）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "42841ec6-63a2-40aa-9621-fc609d67a8c1",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from fastphm.data.FeatureExtractor import FeatureExtractor\n",
    "from fastphm.data.labeler.BearingRulLabeler import BearingRulLabeler\n",
    "from fastphm.data.loader.bearing.XJTULoader import XJTULoader\n",
    "from fastphm.data.processor.RMSProcessor import RMSProcessor\n",
    "from fastphm.data.stage.BearingStageCalculator import BearingStageCalculator\n",
    "from fastphm.data.stage.fpt.ThreeSigmaFPTCalculator import ThreeSigmaFPTCalculator\n",
    "from fastphm.model.pytorch.base.BaseTrainer import BaseTrainer\n",
    "from fastphm.model.pytorch.basic.MLP import MLP\n",
    "from fastphm.system.Cache import Cache"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1967380-798e-4df3-bd88-b0aecd898184",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 生成数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ab1339ac-fe28-4cf4-92a8-07d3c2c71c76",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[DEBUG   13:52:59]  \n",
      "[DataLoader]  Root directory: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\n",
      "\t✓ Bearing1_1, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\35Hz12kN\\Bearing1_1\n",
      "\t✓ Bearing1_2, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\35Hz12kN\\Bearing1_2\n",
      "\t✓ Bearing1_3, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\35Hz12kN\\Bearing1_3\n",
      "\t✓ Bearing1_4, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\35Hz12kN\\Bearing1_4\n",
      "\t✓ Bearing1_5, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\35Hz12kN\\Bearing1_5\n",
      "\t✓ Bearing2_1, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\37.5Hz11kN\\Bearing2_1\n",
      "\t✓ Bearing2_2, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\37.5Hz11kN\\Bearing2_2\n",
      "\t✓ Bearing2_3, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\37.5Hz11kN\\Bearing2_3\n",
      "\t✓ Bearing2_4, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\37.5Hz11kN\\Bearing2_4\n",
      "\t✓ Bearing2_5, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\37.5Hz11kN\\Bearing2_5\n",
      "\t✓ Bearing3_1, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\40Hz10kN\\Bearing3_1\n",
      "\t✓ Bearing3_2, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\40Hz10kN\\Bearing3_2\n",
      "\t✓ Bearing3_3, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\40Hz10kN\\Bearing3_3\n",
      "\t✓ Bearing3_4, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\40Hz10kN\\Bearing3_4\n",
      "\t✓ Bearing3_5, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\40Hz10kN\\Bearing3_5\n",
      "[INFO    13:52:59]  [DataLoader]  -> Loading data entity: Bearing1_3\n",
      "[INFO    13:53:02]  [DataLoader]  ✓ Successfully loaded: Bearing1_3\n"
     ]
    }
   ],
   "source": [
    "data_loader = XJTULoader('D:\\\\data\\\\dataset\\\\XJTU-SY_Bearing_Datasets')\n",
    "feature_extractor = FeatureExtractor(RMSProcessor(data_loader.continuum))\n",
    "fpt_calculator = ThreeSigmaFPTCalculator()\n",
    "stage_calculator = BearingStageCalculator(data_loader.continuum, fpt_calculator)\n",
    "\n",
    "# 获取原始数据、特征数据、阶段数据\n",
    "bearing = data_loader(\"Bearing1_3\", 'Horizontal Vibration')\n",
    "feature_extractor(bearing)\n",
    "stage_calculator(bearing)\n",
    "\n",
    "generator = BearingRulLabeler(2048, time_ratio=60, is_from_fpt=False, is_rectified=True)\n",
    "train_set = generator(bearing)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53c3e103-9156-47d4-966f-9952bc01183c",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 缓存数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "328c7310-aba4-4107-8fe7-0dd7bc4ae521",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[DEBUG   13:53:02]  [Cache]  Generating cache file: .\\cache\\train_set.pkl\n",
      "[DEBUG   13:53:02]  [Cache]  Generated cache file: .\\cache\\train_set.pkl\n"
     ]
    }
   ],
   "source": [
    "Cache.save(train_set, 'train_set')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b341d3a2-7ff6-4af7-8761-9c268b22a615",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 加载数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8d288e84-638b-4f11-8f8a-a965ba9a75e3",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[DEBUG   13:53:02]  [Cache]  -> Loading cache file: .\\cache\\train_set.pkl\n",
      "[DEBUG   13:53:02]  [Cache]  ✓ Successfully loaded: .\\cache\\train_set.pkl\n"
     ]
    }
   ],
   "source": [
    "train_set = Cache.load('train_set')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7d7aa10c-1349-4b9c-80ec-258dbf604e41",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ -0.77383518,   0.1305342 ,   0.1555443 , ...,   0.1136541 ,\n",
       "         -0.54583549,   0.42616129],\n",
       "       [ -0.38670301,  -0.3725886 ,  -0.74648857, ...,  -0.0345588 ,\n",
       "         -0.70579052,  -0.79065561],\n",
       "       [ -0.27912861,   0.52905079,  -0.2827883 , ...,  -0.63660137,\n",
       "          0.25235411,   0.45748949],\n",
       "       ...,\n",
       "       [ -0.48875809,   0.66200487,  -2.59971619, ..., -11.16659641,\n",
       "         -3.66904736,  -6.93216324],\n",
       "       [ -2.04831362,  -0.154233  ,   3.88391018, ...,  -9.79976654,\n",
       "          1.93562508,  -7.22073317],\n",
       "       [ -3.97603512,   2.60374546,  -0.52655939, ...,  -5.99588156,\n",
       "         -4.60753441, -10.23435593]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_set.x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d60e2f81-8146-4d36-aa43-fd7dc4a73c70",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 生成模型并训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "84290f34-94a4-4b34-9ad0-0d1597824b00",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[INFO    13:53:02]  \n",
      "[Trainer]  Start training by BaseTrainer:\n",
      "\ttraining set: Bearing1_3\n",
      "\tcallbacks: []\n",
      "\tdevice: cuda\n",
      "\tdtype: torch.float32\n",
      "\tepochs: 100\n",
      "\tbatch_size: 256\n",
      "\tcriterion: MSELoss()\n",
      "\tlr: 0.01\n",
      "\tweight_decay: 0.0\n",
      "\toptimizer: Adam\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [1/100], MSELoss:2.0840\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [2/100], MSELoss:1.5193\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [3/100], MSELoss:0.7640\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [4/100], MSELoss:0.5087\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [5/100], MSELoss:0.3622\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [6/100], MSELoss:0.2742\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [7/100], MSELoss:0.2121\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [8/100], MSELoss:0.1762\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [9/100], MSELoss:0.1386\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [10/100], MSELoss:0.1104\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [11/100], MSELoss:0.0865\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [12/100], MSELoss:0.0660\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [13/100], MSELoss:0.0512\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [14/100], MSELoss:0.0395\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [15/100], MSELoss:0.0317\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [16/100], MSELoss:0.0260\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [17/100], MSELoss:0.0226\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [18/100], MSELoss:0.0209\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [19/100], MSELoss:0.0198\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [20/100], MSELoss:0.0215\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [21/100], MSELoss:0.0263\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [22/100], MSELoss:0.0259\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [23/100], MSELoss:0.0248\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [24/100], MSELoss:0.0190\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [25/100], MSELoss:0.0182\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [26/100], MSELoss:0.0196\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [27/100], MSELoss:0.0175\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [28/100], MSELoss:0.0126\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [29/100], MSELoss:0.0098\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [30/100], MSELoss:0.0074\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [31/100], MSELoss:0.0058\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [32/100], MSELoss:0.0055\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [33/100], MSELoss:0.0051\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [34/100], MSELoss:0.0046\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [35/100], MSELoss:0.0043\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [36/100], MSELoss:0.0042\n",
      "[INFO    13:53:03]  [BaseTrainer]  Epoch [37/100], MSELoss:0.0038\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [38/100], MSELoss:0.0036\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [39/100], MSELoss:0.0036\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [40/100], MSELoss:0.0038\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [41/100], MSELoss:0.0036\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [42/100], MSELoss:0.0036\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [43/100], MSELoss:0.0043\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [44/100], MSELoss:0.0052\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [45/100], MSELoss:0.0064\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [46/100], MSELoss:0.0105\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [47/100], MSELoss:0.0120\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [48/100], MSELoss:0.0134\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [49/100], MSELoss:0.0146\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [50/100], MSELoss:0.0099\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [51/100], MSELoss:0.0092\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [52/100], MSELoss:0.0090\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [53/100], MSELoss:0.0063\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [54/100], MSELoss:0.0056\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [55/100], MSELoss:0.0041\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [56/100], MSELoss:0.0035\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [57/100], MSELoss:0.0035\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [58/100], MSELoss:0.0030\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [59/100], MSELoss:0.0030\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [60/100], MSELoss:0.0028\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [61/100], MSELoss:0.0023\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [62/100], MSELoss:0.0022\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [63/100], MSELoss:0.0019\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [64/100], MSELoss:0.0018\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [65/100], MSELoss:0.0017\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [66/100], MSELoss:0.0018\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [67/100], MSELoss:0.0019\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [68/100], MSELoss:0.0016\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [69/100], MSELoss:0.0016\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [70/100], MSELoss:0.0015\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [71/100], MSELoss:0.0013\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [72/100], MSELoss:0.0011\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [73/100], MSELoss:0.0010\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [74/100], MSELoss:0.0009\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [75/100], MSELoss:0.0009\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [76/100], MSELoss:0.0008\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [77/100], MSELoss:0.0008\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [78/100], MSELoss:0.0008\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [79/100], MSELoss:0.0007\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [80/100], MSELoss:0.0007\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [81/100], MSELoss:0.0007\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [82/100], MSELoss:0.0007\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [83/100], MSELoss:0.0008\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [84/100], MSELoss:0.0007\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [85/100], MSELoss:0.0007\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [86/100], MSELoss:0.0007\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [87/100], MSELoss:0.0007\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [88/100], MSELoss:0.0009\n",
      "[INFO    13:53:04]  [BaseTrainer]  Epoch [89/100], MSELoss:0.0009\n",
      "[INFO    13:53:05]  [BaseTrainer]  Epoch [90/100], MSELoss:0.0008\n",
      "[INFO    13:53:05]  [BaseTrainer]  Epoch [91/100], MSELoss:0.0009\n",
      "[INFO    13:53:05]  [BaseTrainer]  Epoch [92/100], MSELoss:0.0008\n",
      "[INFO    13:53:05]  [BaseTrainer]  Epoch [93/100], MSELoss:0.0011\n",
      "[INFO    13:53:05]  [BaseTrainer]  Epoch [94/100], MSELoss:0.0013\n",
      "[INFO    13:53:05]  [BaseTrainer]  Epoch [95/100], MSELoss:0.0016\n",
      "[INFO    13:53:05]  [BaseTrainer]  Epoch [96/100], MSELoss:0.0019\n",
      "[INFO    13:53:05]  [BaseTrainer]  Epoch [97/100], MSELoss:0.0020\n",
      "[INFO    13:53:05]  [BaseTrainer]  Epoch [98/100], MSELoss:0.0043\n",
      "[INFO    13:53:05]  [BaseTrainer]  Epoch [99/100], MSELoss:0.0058\n",
      "[INFO    13:53:05]  [BaseTrainer]  Epoch [100/100], MSELoss:0.0038\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'MSELoss': [2.084027421474457,\n",
       "  1.5193231225013732,\n",
       "  0.7640481412410736,\n",
       "  0.508734080195427,\n",
       "  0.36224835813045503,\n",
       "  0.27415941953659057,\n",
       "  0.21211764365434646,\n",
       "  0.17620156556367875,\n",
       "  0.1386438399553299,\n",
       "  0.11043676733970642,\n",
       "  0.0865045927464962,\n",
       "  0.0659977313131094,\n",
       "  0.05117107555270195,\n",
       "  0.03954384885728359,\n",
       "  0.031665900722146034,\n",
       "  0.026047676056623458,\n",
       "  0.022646864876151085,\n",
       "  0.020868712477385996,\n",
       "  0.01979600079357624,\n",
       "  0.021487693581730128,\n",
       "  0.026318890601396562,\n",
       "  0.02589519917964935,\n",
       "  0.024762030690908432,\n",
       "  0.019007534254342318,\n",
       "  0.018213863577693702,\n",
       "  0.01960498243570328,\n",
       "  0.01751903733238578,\n",
       "  0.012640838138759136,\n",
       "  0.00981745389290154,\n",
       "  0.007396195316687226,\n",
       "  0.005810825573280454,\n",
       "  0.005492550041526556,\n",
       "  0.005126125691458583,\n",
       "  0.00455347269307822,\n",
       "  0.004289551009424031,\n",
       "  0.004221928515471518,\n",
       "  0.003778877342119813,\n",
       "  0.00362603475805372,\n",
       "  0.0036148195154964923,\n",
       "  0.0038201405899599195,\n",
       "  0.003642614930868149,\n",
       "  0.003639403195120394,\n",
       "  0.00430621113628149,\n",
       "  0.005204852437600493,\n",
       "  0.006423478364013135,\n",
       "  0.010468336241319775,\n",
       "  0.011998951504938304,\n",
       "  0.013356059789657593,\n",
       "  0.014569108001887798,\n",
       "  0.009858300024643541,\n",
       "  0.009186186920851469,\n",
       "  0.008978211088106036,\n",
       "  0.006310665514320135,\n",
       "  0.0056049743667244915,\n",
       "  0.004102144972421229,\n",
       "  0.003480981173925102,\n",
       "  0.0035333273699507117,\n",
       "  0.0029899917426519096,\n",
       "  0.002957450761459768,\n",
       "  0.0027896881219930947,\n",
       "  0.0023249710095115004,\n",
       "  0.002200034447014332,\n",
       "  0.001859261398203671,\n",
       "  0.0018193585914559663,\n",
       "  0.0017176223103888334,\n",
       "  0.0018009327002801,\n",
       "  0.0018897187081165611,\n",
       "  0.0015976498078089207,\n",
       "  0.0016474207281135024,\n",
       "  0.0014828933053649962,\n",
       "  0.0012655390542931854,\n",
       "  0.001063406461616978,\n",
       "  0.000972395925782621,\n",
       "  0.0009403528732946142,\n",
       "  0.0008642218686873093,\n",
       "  0.0008171145309461281,\n",
       "  0.0007889578962931409,\n",
       "  0.0007879042532294989,\n",
       "  0.0007418268331093713,\n",
       "  0.000724283661111258,\n",
       "  0.0007216298778075725,\n",
       "  0.0007378013426205144,\n",
       "  0.0007669764920137823,\n",
       "  0.000712623875006102,\n",
       "  0.0007213517499621957,\n",
       "  0.0007143760187318549,\n",
       "  0.000729852548101917,\n",
       "  0.0008831887709675357,\n",
       "  0.0009348701976705342,\n",
       "  0.0008443590137176216,\n",
       "  0.000862212257925421,\n",
       "  0.0008210912434151396,\n",
       "  0.0010853734624106438,\n",
       "  0.0012928069598274305,\n",
       "  0.001616554317297414,\n",
       "  0.001945944072213024,\n",
       "  0.0020404533657711,\n",
       "  0.004281812533736229,\n",
       "  0.005796045181341469,\n",
       "  0.003832762502133846]}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer = BaseTrainer()\n",
    "model = MLP(2048,16,1)\n",
    "trainer.train(model,train_set)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d746f0ee-09ad-42c0-834d-28b07ad752b5",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 缓存模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fb2b25c8-1e1c-4ea0-a8ed-84ef9f060ee6",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[DEBUG   13:53:05]  [Cache]  Generating cache file: .\\cache\\model.pkl\n",
      "[DEBUG   13:53:05]  [Cache]  Generated cache file: .\\cache\\model.pkl\n"
     ]
    }
   ],
   "source": [
    "Cache.save(model, 'model')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a5f3640-3d0f-4ef7-a526-577fe95b6df3",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 加载模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7f03508b-02f5-440b-9250-9bd405c58f89",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[DEBUG   13:53:05]  [Cache]  -> Loading cache file: .\\cache\\model.pkl\n",
      "[DEBUG   13:53:05]  [Cache]  ✓ Successfully loaded: .\\cache\\model.pkl\n"
     ]
    }
   ],
   "source": [
    "model = Cache.load('model')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
